#!/usr/bin/env python
"""Utilities for handling sets.

Exported functions
------------------
:py:func:`merge_sets`
    Recursively merge sets that share common members, until all groups
    sharing common members are merged.

:py:func:`get_random_sets`
    Generate random sets of members
"""
import copy


def merge_sets(list_of_sets, printer=None):
    """Merge sets in a list that have one or more common members, until all
    groups sharing common members are merged.

    Merging takes place as follows:

        1.  Duplicate input sets are discarded and all members of all sets
            are collected

        2.  All sets containing more than one member are merged via
            a call to :func:`_merge_sets`, which repeats its merging pass
            internally (this function itself makes a single pass over
            steps 1-4)

        3.  A list of unmerged single-member sets is generated by finding all
            members that appear in multi-member sets after step (2), and
            subtracting this set from the set collected in step (1)

        4.  The multi- and single-member sets are concatenated to create
            the final list

    Parameters
    ----------
    list_of_sets : list
        List of sets (or other iterables) of hashable items

    printer : file-like, or `None`
        Writer to which status is sent, as string, if not `None`
        (Default: `None`)

    Returns
    -------
    list
        list of merged sets (as :class:`set`). Empty input sets contribute
        nothing. When all members are mutually orderable, the sets are
        ordered by their sorted members; otherwise the order is arbitrary.
    """
    # first, simplify sets to remove redundancy (duplicates collapse here)
    unique_sets = frozenset(frozenset(X) for X in list_of_sets)

    # collect all distinct members
    all_members = set().union(*unique_sets)

    # ignore singletons in merge process
    multis = [X for X in unique_sets if len(X) > 1]
    if printer is not None:
        stmp = "Starting with %s sets, %s distinct members, and %s groups with multiple items ..." % (
            len(unique_sets), len(all_members), len(multis)
        )
        printer.write(stmp)

    # find remaining singletons post-merge: any member not absorbed into a
    # merged multi-member set becomes its own single-member set
    merged_multis = [set(X) for X in _merge_sets(multis, printer=printer)]
    merged_members = set().union(*merged_multis)
    unmerged_sets = [{X} for X in all_members - merged_members]

    result = merged_multis + unmerged_sets

    # A plain `sorted(result)` compares sets with `<`, which is the proper
    # *subset* test: for disjoint sets every comparison is False, so it
    # silently leaves the arbitrary order unchanged. Sort by each set's sorted
    # members instead; if members cannot be compared (e.g. int vs str), keep
    # the arbitrary order rather than failing.
    try:
        return sorted(result, key=sorted)
    except TypeError:
        return result


def _merge_sets(list_of_sets, printer=None):
    """Inner loop function for :func:`merge_sets`.

    Parameters
    ----------
    list_of_sets : list
        List of multimember sets of hashable items to merge

    printer : file-like or `None`, optional
        Writer to which status is sent, as string, if not `None`
        (Default: :`None`)

    Returns
    -------
    list
        merged list of sets
    """
    if printer is not None:
        stmp = "Starting with %s sets..." % len(list_of_sets)
        printer.write(stmp)

    new_sets = []
    for set_a in list_of_sets:
        new_set = set(copy.deepcopy(set_a))
        for set_b in list_of_sets:
            if len(new_set & set_b) > 0:
                new_set |= set_b

        new_sets.append(frozenset(new_set))

    ltmp = list(set(new_sets))
    if printer is not None:
        stmp = "Merged %s starting sets to %s final sets ..." % (len(list_of_sets), len(ltmp))
        printer.write(stmp)

    # terminate if no new mergers
    if len(ltmp) == len(list_of_sets):
        return [set(X) for X in ltmp]
    # otherwise, recurse
    else:
        return _merge_sets(ltmp, printer=printer)


def get_random_sets(num_sets, max_len=3, members=None, random_state=None):
    """Generates random sets of members

    Parameters
    ----------
    num_sets : int
        number of sets to generate

    max_len : int
        maximum length for each set (should be at least 1). For each set, a
        number of draws between 1 and `max_len` is chosen; because members
        are drawn with replacement, a set may end up shorter than that.

    members : list-like or `None`
        Sequence of hashable members to put in each set
        (Default: `None`, meaning the letters a-p)

    random_state : object with a ``randint`` method, or `None`, optional
        Source of randomness, e.g. a :class:`numpy.random.RandomState`.
        If `None` (the default), the global :mod:`numpy.random` state is
        used, so results follow ``numpy.random.seed``.

    Returns
    -------
    list
        List of generated sets
    """
    import numpy.random

    # Build the default per call instead of sharing one mutable list default
    # across all calls.
    if members is None:
        members = list("abcdefghijklmnop")

    rng = numpy.random if random_state is None else random_state
    lout = []
    for _ in range(num_sets):
        num_in_set = rng.randint(1, max_len + 1)
        stmp = set()
        for _ in range(num_in_set):
            stmp.add(members[rng.randint(len(members))])
        lout.append(stmp)
    return lout
